In this notebook, we will perform an EDA (Exploratory Data Analysis) on the processed Waymo dataset (data in the processed folder). In the first part, you will create a function to display images together with their bounding boxes.
from utils import get_dataset
import glob
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import tensorflow as tf
%matplotlib inline
# Load the dataset through the project helper, passing a glob pattern for
# the .tfrecord files in the train folder.
dataset = get_dataset('/home/workspace/data/train/*.tfrecord')
dataset  # bare expression: rendered as the cell output when run in a notebook

# Print the first three records to inspect their structure.
for element in dataset.take(3):
    print(element)
Implement the display_images function below. This function takes a batch as input and displays an image with its corresponding bounding boxes. The only requirement is that the classes should be color coded (e.g., vehicles in red, pedestrians in blue, cyclists in green).
def display_images(batch):
    """Display the image in ``batch`` with color-coded bounding boxes.

    Args:
        batch: Mapping with the entries ``'image'``, ``'groundtruth_boxes'``
            and ``'groundtruth_classes'``; each value must expose
            ``.numpy()``. Every box is read as
            ``[ymin, xmin, ymax, xmax]`` and scaled by the image height and
            width, i.e. the coordinates are treated as normalized.

    Class ids 1, 2 and 4 are drawn in red, blue and green respectively. Any
    other class id is drawn in a fallback color instead of raising
    ``KeyError``.
    """
    bbox_colors = {
        1: 'red',    # cars
        2: 'blue',   # pedestrians
        4: 'green',  # cyclists
    }
    # Used for class ids that have no entry in bbox_colors.
    fallback_color = 'yellow'

    image = batch['image'].numpy()
    ground_truth_boxes = batch['groundtruth_boxes'].numpy()
    ground_truth_labels = batch['groundtruth_classes'].numpy()
    img_height, img_width = image.shape[0], image.shape[1]

    _, ax = plt.subplots(figsize=(15, 15))

    for bbox, label in zip(ground_truth_boxes, ground_truth_labels):
        y_min, x_min, y_max, x_max = bbox[0], bbox[1], bbox[2], bbox[3]
        # Rectangle takes the (x, y) corner followed by width and height,
        # all in pixel units here.
        anchor = (x_min * img_width, y_min * img_height)
        width = (x_max - x_min) * img_width
        height = (y_max - y_min) * img_height
        box_patch = patches.Rectangle(
            anchor,
            width,
            height,
            linewidth=1,
            facecolor='None',
            edgecolor=bbox_colors.get(int(label), fallback_color),
        )
        ax.add_patch(box_patch)

    # remove tick labels
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image)
Using the dataset created in the second cell and the function you just coded, display 10 random images with the associated bounding boxes. You can use the methods take and shuffle on the dataset.
# Draw 10 records from the dataset after shuffling with a buffer of 50,
# and render each one with its bounding boxes.
for sample in dataset.shuffle(50).take(10):
    display_images(sample)
In this last part, you are free to perform any additional analysis of the dataset. What else would you like to know about the data? For example, think about data distribution. So far, you have only looked at a single file...
# Display name for each ground-truth class id.
object_names_dict = {1: 'cars', 2: 'pedestrians', 4: 'cyclists'}

# How many objects of each type over all images? (one zeroed counter per id)
object_count_dict = dict.fromkeys(object_names_dict, 0)

# How many images have which class(es) in them?
# An image is categorized by the sum of the class ids present in it; since
# the ids are 1, 2 and 4, every combination gets its own code from 0 to 7.
classes_names_dict = {
    0: 'no objects',
    1: 'just cars',
    2: 'just pedestrians',
    3: 'no cyclists',       # cars + pedestrians
    4: 'just cyclists',
    5: 'no pedestrians',    # cars + cyclists
    6: 'no cars',           # pedestrians + cyclists
    7: 'all three',
}

# Number of images seen for each combination code, all starting at zero.
classes_count_dict = dict.fromkeys(classes_names_dict, 0)
# Upper bound on the number of records to scan.
sample_size = 50000

for record in dataset.take(sample_size):
    labels = record['groundtruth_classes'].numpy()

    # Per-image tally, used only to tell which classes are present.
    per_image = dict.fromkeys((1, 2, 4), 0)
    for label in labels:
        per_image[label] += 1
        object_count_dict[label] += 1

    # Sum the ids of the classes that occur at least once; with ids 1, 2
    # and 4 this gives a distinct code (0-7) per combination of classes.
    combination = sum(
        class_id for class_id, count in per_image.items() if count
    )
    classes_count_dict[combination] += 1

print(classes_count_dict)
# Bar chart: number of images per class-combination category.
category_names = list(classes_names_dict.values())
images_per_category = classes_count_dict.values()

plt.bar(category_names, images_per_category)
# Vertical tick labels keep the longer category names from overlapping.
plt.xticks(category_names, category_names, rotation=90)
plt.title(f'Distribution of objects in {sample_size} images out of the dataset')
plt.xlabel('objects in image')
plt.ylabel('number of images')
plt.show()
# Bar chart: total number of annotated objects per class.
bars = list(object_names_dict.values())
# Look counts up by class id so each bar is paired with its own name
# rather than relying on both dicts sharing the same insertion order.
counts = [object_count_dict[class_id] for class_id in object_names_dict]

plt.bar(bars, counts)
plt.xticks(bars, bars, rotation=90)
# The axes show classes and object totals (not images), so label them so.
plt.xlabel('object class')
plt.ylabel('number of objects')
plt.title(f'Number of total objects in {sample_size} images out of the dataset')
plt.show()